{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mf3kOv1YMB5y"
      },
      "source": [
        "# Adaptive warmup length\n",
        "\n",
        "This notebook experiments the use of adaptively terminating the warmup length\n",
        "\n",
        "Copyright 2021 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-rOdskBSMfQN"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dSA8AyBbqPsK"
      },
      "outputs": [],
      "source": [
        "# Import tf first to enable eager mode.\n",
        "import tensorflow as tf\n",
        "tf.executing_eagerly()\n",
        "\n",
        "import numpy as np\n",
        "from matplotlib.pyplot import *\n",
        "%config InlineBackend.figure_format = 'retina'\n",
        "matplotlib.pyplot.style.use(\"dark_background\")\n",
        "\n",
        "import jax\n",
        "from jax import random\n",
        "from jax import numpy as jnp\n",
        "\n",
        "from colabtools import adhoc_import\n",
        "\n",
        "# import tensforflow_datasets\n",
        "from inference_gym import using_jax as gym\n",
        "\n",
        "from tensorflow_probability.spinoffs.fun_mc import using_jax as fun_mcmc\n",
        "\n",
        "\n",
        "# import tensorflow as tf\n",
        "from tensorflow_probability.python.internal import prefer_static as ps\n",
        "from tensorflow_probability.python.internal import unnest\n",
        "\n",
        "\n",
        "import tensorflow_probability as _tfp\n",
        "tfp = _tfp.substrates.jax\n",
        "tfd = tfp.distributions\n",
        "tfb = tfp.bijectors\n",
        "\n",
        "tfp_np = _tfp.substrates.numpy\n",
        "tfd_np = tfp_np.distributions \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OfipSLXY0SXt"
      },
      "outputs": [],
      "source": [
        "# tested options: Bananas. other: GermanCredit, EightSchools\n",
        "problem_name = 'EightSchools'\n",
        "\n",
        "if problem_name == 'Bananas':\n",
        "  target = gym.targets.VectorModel(gym.targets.Banana(),\n",
        "                                    flatten_sample_transformations=True)\n",
        "  num_dimensions = target.event_shape[0]  \n",
        "  init_step_size = 1.\n",
        "\n",
        "if problem_name == 'GermanCredit':\n",
        "  # This problem seems to require that we load TF datasets first.\n",
        "  import tensorflow_datasets\n",
        "  target = gym.targets.VectorModel(gym.targets.GermanCreditNumericSparseLogisticRegression(),\n",
        "                                    flatten_sample_transformations=True)\n",
        "  num_dimensions = target.event_shape[0]\n",
        "  init_step_size = 0.02\n",
        "\n",
        "if problem_name == 'Brownian':\n",
        "  target = gym.targets.BrownianMotionMissingMiddleObservations()\n",
        "  target = gym.targets.VectorModel(target,\n",
        "                                    flatten_sample_transformations = True)\n",
        "  num_dimensions = target.event_shape[0]\n",
        "  init_step_size = 0.01\n",
        "\n",
        "# NOTE: this loads the centered parameterization... (use code below to\n",
        "# get non-centered parameterization). This code is still useful to get\n",
        "# the correct parameter values.\n",
        "if problem_name == 'EightSchools':\n",
        "  target_raw = gym.targets.EightSchools()  # store raw to examine doc.\n",
        "  target = gym.targets.VectorModel(target_raw,\n",
        "                                    flatten_sample_transformations = True)\n",
        "  num_dimensions = target.event_shape[0]\n",
        "  init_step_size = 1\n",
        "\n",
        "\n",
        "def target_log_prob_fn(x):\n",
        "  \"\"\"Unnormalized, unconstrained target density.\n",
        "\n",
        "  This is a thin wrapper that applies the default bijectors so that we can\n",
        "  ignore any constraints.\n",
        "  \"\"\"\n",
        "  y = target.default_event_space_bijector(x)\n",
        "  fldj = target.default_event_space_bijector.forward_log_det_jacobian(x)\n",
        "  return target.unnormalized_log_prob(y) + fldj\n",
        "\n",
        "if problem_name == 'Bananas':\n",
        "  offset = 2\n",
        "  def initialize (shape, key = random.PRNGKey(37272709)):\n",
        "    return 3 * random.normal(key, shape + (num_dimensions,)) + offset\n",
        "\n",
        "if problem_name == 'GermanCredit':\n",
        "  offset = 0.1\n",
        "  def initialize (shape, key = random.PRNGKey(37272709)):\n",
        "    return 0.5 * random.normal(key, shape + (num_dimensions,)) + offset\n",
        "\n",
        "# Using underdispersed initis can show case problems with our diagnostics.\n",
        "underdispered = False\n",
        "if problem_name == 'EightSchools':\n",
        "  if underdispered:\n",
        "    offset = 0.0\n",
        "    def initialize (shape, key = random.PRNGKey(37272709)):\n",
        "      return 1 * random.normal(key, shape + (num_dimensions,)) + offset\n",
        "      # return 3 * random.normal(key, shape + (num_dimensions,)) + offset\n",
        "  else:\n",
        "    def initialize (shape, key = random.PRNGKey(37272709)):\n",
        "     prior_scale = jnp.append(jnp.array([10., 1.]), jnp.repeat(1., 8))\n",
        "     prior_offset = jnp.append(jnp.array([0., 5.]), jnp.repeat(0., 8))\n",
        "     return prior_scale * random.normal(key, shape + (num_dimensions,)) + prior_offset\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jSN6gbVJd_ou"
      },
      "outputs": [],
      "source": [
        "target_raw = gym.targets.EightSchools()\n",
        "# target = gym.targets.GermanCreditNumericSparseLogisticRegression()\n",
        "# print(super(type(target_raw), target_raw).__doc__)\n",
        "print(type(target_raw).__doc__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1k2X3Z1fsGdp"
      },
      "source": [
        "Specify eight school model using a non-centered parameterization.\n",
        "\\begin{eqnarray*}\n",
        "\\mu \u0026 \\sim \u0026 N(0, 10)  \\\\\n",
        "\\log \\tau \u0026 \\sim \u0026 N(5, 1) \\\\\n",
        "\\eta \u0026 \\sim \u0026 N(0, 1) \\\\\n",
        "\\theta \u0026 = \u0026 \\mu + \\tau \\theta \\\\\n",
        "y \u0026 \\sim \u0026 N(\\theta, \\sigma)\n",
        "\\end{eqnarray*}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A8tZpy0U8A2r"
      },
      "outputs": [],
      "source": [
        "if problem_name == \"EightSchools\":\n",
        "  num_schools = 8\n",
        "  y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype = np.float32)\n",
        "  sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype = np.float32)\n",
        "\n",
        "  # NOTE: the reinterpreted batch dimension specifies the dimension of\n",
        "  # each indepdent variable, here the school.\n",
        "  model = tfd.JointDistributionSequential([\n",
        "    tfd.Normal(loc = 0., scale = 10., name = \"mu\"),\n",
        "    tfd.Normal(loc = 5., scale = 1., name = \"log_tau\"),\n",
        "    tfd.Independent(tfd.Normal(loc = jnp.zeros(num_schools),\n",
        "                               scale = jnp.ones(num_schools),\n",
        "                               name = \"eta\"),\n",
        "                    reinterpreted_batch_ndims = 1),\n",
        "    lambda eta, log_tau, mu: (\n",
        "        tfd.Independent(tfd.Normal(loc = (mu[..., jnp.newaxis] +\n",
        "                                        jnp.exp(log_tau[..., jnp.newaxis]) *\n",
        "                                        eta),\n",
        "                                   scale = sigma),\n",
        "                        name = \"y\",\n",
        "                        reinterpreted_batch_ndims = 1))\n",
        "  ])\n",
        "\n",
        "  def target_log_prob_fn(x):\n",
        "    mu = x[:, 0]\n",
        "    log_tau = x[:, 1]\n",
        "    eta = x[:, 2:10]\n",
        "    return model.log_prob((mu, log_tau, eta, y))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X50f-fxJ5pHv"
      },
      "outputs": [],
      "source": [
        "initial_state = initialize((4,), key = jax.random.PRNGKey(1))\n",
        "evaluated_density = target_log_prob_fn(initial_state)\n",
        "evaluated_density"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PeR-wW64Jzo2"
      },
      "outputs": [],
      "source": [
        "# Get some estimates of the mean and variance.\n",
        "# WARNING: for EightSchool problem, the correct value is for the centered\n",
        "# parameterization.\n",
        "try:\n",
        "  mean_est = target.sample_transformations['identity'].ground_truth_mean\n",
        "except:\n",
        "  print('no ground truth mean')\n",
        "  mean_est = (result.all_states[num_warmup:, :]).mean(0).mean(0)\n",
        "try:\n",
        "  var_est = target.sample_transformations['identity'].ground_truth_standard_deviation**2\n",
        "except:\n",
        "  print('no ground truth std dev')\n",
        "  var_est = ((result.all_states[num_warmup:, :]**2).mean(0).mean(0) -\n",
        "             mean_est**2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6DZTdfm3gq-x"
      },
      "outputs": [],
      "source": [
        "print(mean_est)\n",
        "print(var_est)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7l98WlSo0ZR3"
      },
      "outputs": [],
      "source": [
        "# Follow procedure described in source code for potential scale reduction.\n",
        "# NOTE: some of the tf argument need to be adjusted (e.g. keepdims = False,\n",
        "# instead of True). Not quite sure why.\n",
        "# QUESTION: can these be accessed as internal functions of tf?\n",
        "# TODO: following Pavel's example, rewrite this without using tf.\n",
        "# TODO: add error message when the number of samples is less than 2.\n",
        "\n",
        "# REMARK: this function doesn't seem to work, returns NaN.\n",
        "# As a result, can only use _reduce_variance with biased =  False.\n",
        "def _axis_size(x, axis = None):\n",
        "  \"\"\"Get number of elements of `x` in `axis`, as type `x.dtype`.\"\"\"\n",
        "  if axis is None:\n",
        "    return ps.cast(ps.size(x), x.dtype)\n",
        "  return ps.cast(\n",
        "      ps.reduce_prod(\n",
        "          ps.gather(ps.shape(x), axis)), x.dtype)\n",
        "\n",
        "def _reduce_variance(x, axis=None, biased=True, keepdims=False):\n",
        "  with tf.name_scope('reduce_variance'):\n",
        "    x = tf.convert_to_tensor(x, name='x')\n",
        "    mean = tf.reduce_mean(x, axis=axis, keepdims=True)\n",
        "    biased_var = tf.reduce_mean(\n",
        "        tf.math.squared_difference(x, mean), axis=axis, keepdims=keepdims)\n",
        "    if biased:\n",
        "      return biased_var\n",
        "    n = _axis_size(x, axis)\n",
        "    return (n / (n - 1.)) * biased_var\n",
        "\n",
        "def nested_rhat(result_state, num_super_chain):\n",
        "  used_samples = result_state.shape[0]\n",
        "  num_sub_chains = result_state.shape[1] // num_super_chains\n",
        "  num_dimensions = result_state.shape[2]\n",
        "\n",
        "  chain_states = result_state.reshape(used_samples, -1, num_sub_chains,\n",
        "                                      num_dimensions)\n",
        "\n",
        "  state = tf.convert_to_tensor(chain_states, name = 'state')\n",
        "  mean_chain = tf.reduce_mean(state, axis = 0)\n",
        "  mean_super_chain = tf.reduce_mean(state, axis = [0, 2])\n",
        "  variance_chain = _reduce_variance(state, axis = 0, biased = False)\n",
        "  variance_super_chain = _reduce_variance(mean_chain, axis = 1, biased = False) \\\n",
        "     + tf.reduce_mean(variance_chain, axis = 1)\n",
        "\n",
        "  W = tf.reduce_mean(variance_super_chain, axis = 0)\n",
        "  B = _reduce_variance(mean_super_chain, axis = 0, biased = False)\n",
        "\n",
        "  return tf.sqrt((W + B) / W)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_5oZPsPp0ql9"
      },
      "outputs": [],
      "source": [
        "def forge_chain (target_rhat, warmup_window_size, kernel_cold, initial_state,\n",
        "                 max_num_steps, seed, monitor = False,\n",
        "                 use_nested_rhat = True, use_log_joint = False,\n",
        "                 num_super_chains = 4):\n",
        "  # store certain variables\n",
        "  rhat_forge = np.array([])\n",
        "  warmup_is_acceptable = False\n",
        "  store_results = []\n",
        "\n",
        "  warmup_iteration = 0\n",
        "\n",
        "  current_state = initial_state\n",
        "  final_kernel_args = None\n",
        "\n",
        "  while (not warmup_is_acceptable and warmup_iteration \u003c= max_num_steps):\n",
        "    warmup_iteration += 1\n",
        "\n",
        "    # 1) Run MCMC on short warmup window\n",
        "    result_cold, target_log_prob, final_kernel_args = tfp.mcmc.sample_chain(\n",
        "        num_results = warmup_window_size,\n",
        "        current_state = current_state,\n",
        "        kernel = kernel_cold,\n",
        "        previous_kernel_results = final_kernel_args,\n",
        "        seed = seed,\n",
        "        trace_fn = lambda _, pkr: unnest.get_innermost(pkr, 'target_log_prob'),\n",
        "        return_final_kernel_results = True)\n",
        "\n",
        "    if warmup_iteration == 1:\n",
        "      store_results = result_cold\n",
        "    else : \n",
        "      store_results = np.append(store_results, result_cold, axis = 0)\n",
        "\n",
        "    current_state = result_cold[-1]\n",
        "\n",
        "    # 2) Check if warmup is acceptable\n",
        "    if use_nested_rhat:\n",
        "      if use_log_joint:\n",
        "        shape_lp = target_log_prob.shape\n",
        "        rhat_warmup = nested_rhat(target_log_prob.reshape(shape_lp[0], shape_lp[1], 1),\n",
        "                                  num_super_chains)\n",
        "      else:\n",
        "        rhat_warmup = max(nested_rhat(result_cold, num_super_chains))\n",
        "    else:\n",
        "      if use_log_joint:\n",
        "        rhat_warmup = tfp.mcmc.potential_scale_reduction(target_log_prob)\n",
        "      else:\n",
        "        rhat_warmup = max(tfp.mcmc.potential_scale_reduction(result_cold))\n",
        "\n",
        "    if rhat_warmup \u003c target_rhat: warmup_is_acceptable = True\n",
        "\n",
        "    save_values = True\n",
        "    if save_values:\n",
        "      rhat_forge = np.append(rhat_forge, rhat_warmup)\n",
        "    # While loop ends\n",
        "\n",
        "  return store_results, final_kernel_args, rhat_forge"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y9lV8LLeCp4b"
      },
      "outputs": [],
      "source": [
        "def mc_est_warm(x, axis = 0):\n",
        "  \"\"\" compute running average without discarding half of the samples.\"\"\"\n",
        "  return np.cumsum(x, axis) / np.arange(1, x.shape[0] + 1).reshape([-1] + [1] * (len(x.shape) - 1))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F9Fkaa1j1UqA"
      },
      "outputs": [],
      "source": [
        "# Set up adaptive warmup scheme\n",
        "init_step_size = 1 # 0.1  # CHECK: how should this be set?\n",
        "max_warmup_length = 1000  # CHECK: how should this be set?\n",
        "\n",
        "# define kernel using most recent step size\n",
        "kernel_cold = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn, init_step_size, 1)\n",
        "kernel_cold = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel_cold, max_warmup_length)\n",
        "kernel_cold = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
        "      kernel_cold, max_warmup_length, target_accept_prob = 0.75,\n",
        "      reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)\n",
        "# kernel_cold = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, init_step_size)\n",
        "\n",
        "# kernel for sampling phase\n",
        "kernel_warm = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn, init_step_size, 1)\n",
        "kernel_warm = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel_warm, 0)\n",
        "kernel_warm = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
        "      kernel_warm, 0, target_accept_prob = 0.75,\n",
        "      reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YIi2rGY32Wa2"
      },
      "outputs": [],
      "source": [
        "# initial_state = initialize((4,), key = jax.random.PRNGKey(1))\n",
        "# result_short = tfp.mcmc.sample_chain(\n",
        "#     num_results = 2000, current_state = initial_state, kernel = kernel_cold,\n",
        "#     seed = random.PRNGKey(1954))\n",
        "\n",
        "# states = result_short.all_states[1000:2000, :, :]\n",
        "# print(\"rhat:\", tfp.mcmc.potential_scale_reduction(states, independent_chain_ndims = 1).T)\n",
        "# print(\"mean:\", np.mean(states.mean(0), axis = 0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "29Lo2TwfSgH1"
      },
      "source": [
        "Intuitively, if we warm up our chains well and they forget their initial point, we expect each chain to generate independent samples. Hence, if our goal is to generate an effective sample size of 100 (and we're committed to properly warming up our chains), then we should be able to reach our goal using 100 chains, each with one sampling iteration. One way to check this is to do the adaptive warmup, generate one sample, and then check if the squared error of the Monte Carlo estimate is about $\\mathrm{Var}(\\theta) / n_\\mathrm{chains}$.\n",
        "\n",
        "One issue is that our estimated squared error can be quite noisy. Suppose we run a model fit $M$ times, each generating a Monte Carlo estimator $\\hat \\theta^{(m)}$. Leveraring the precise estimator from the inference gym, we then compute the squared error as\n",
        "$$\n",
        "  (\\hat \\theta^{(m)} - \\theta^*)^2.\n",
        "$$\n",
        "Let's investigate the property of this estimator.\n",
        "\n",
        "Under a CLT, meaning $m$ is relatively large, we get \n",
        "$$\n",
        "  \\hat \\theta^{(m)} \\overset{\\mathrm{approx}}{\\sim} \\mathrm{Normal} \\left(\\theta^*, \\frac{\\sigma^2}{N} \\right),\n",
        "$$\n",
        "where $N$ is the effective sample size. Then\n",
        "$$\n",
        "  (\\hat \\theta^{(m)} - \\theta^*)^2 \\overset{\\mathrm{approx}}{\\sim} \\frac{\\sigma^2}{N} \\chi^2_M,\n",
        "$$\n",
        "where we have $M$ degrees of freedom, since $\\theta^*$ is not estimated using our sample. The estimator\n",
        "$$\n",
        "\\hat \\tau^2 = \\frac{1}{M} \\sum_i (\\hat \\theta^{(m)} - \\theta^*)^2\n",
        "$$\n",
        "has the following properties:\n",
        "$$\n",
        "\\mathbb E \\hat \\tau^2 \\approx \\sigma^2 / N\n",
        "$$\n",
        "and\n",
        "$$\n",
        "\\mathrm{Var} \\hat \\tau^2 \\approx \\frac{2 \\sigma^2}{MN}.\n",
        "$$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-m6jxH8iSyXo"
      },
      "outputs": [],
      "source": [
        "# NOTE: need to control the experiment to avoid exhausting the memory\n",
        "num_chains_array = [16, 32, 64, 128, 256]  # (recommended for Bananas)\n",
        "# num_chains_array = [256, 512]  # [16, 32, 64, 128, 256]\n",
        "# num_chains_array = [16, 20, 24, 28, 32, 64, 128]\n",
        "# num_chains_array = [512]\n",
        "# num_chains_array = [2048]  # [32, 64]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yryGMNFI00ko"
      },
      "outputs": [],
      "source": [
        "\n",
        "num_super_chains = 4  # 4\n",
        "target_rhat = 1.01\n",
        "warmup_window_size = 50  # 25  # 300\n",
        "max_num_steps = 1000 // warmup_window_size  # 1000  # 1\n",
        "mc_err_mean = np.array([])\n",
        "mc_err_median = np.array([])\n",
        "mc_err_sd = np.array([])\n",
        "length_warmup_mean = np.array([])\n",
        "length_warmup_sd = np.array([])\n",
        "nested_rhat_store = np.array([])\n",
        "\n",
        "parameter_index = 0  # parameter of interest\n",
        "\n",
        "# TODO: parallelize this while still allowing each run to have a different\n",
        "# warmup length.\n",
        "for num_chains_short in num_chains_array:\n",
        "\n",
        "  length_warmup = np.array([])\n",
        "  mc_err = np.array([])\n",
        "\n",
        "  for seed in jax.random.split(jax.random.PRNGKey(1), 10):  # 10  # 30\n",
        "    initial_state = initialize((num_super_chains,), key = seed)\n",
        "    initial_state = np.repeat(initial_state, num_chains_short // num_super_chains,\n",
        "                              axis = 0)\n",
        "\n",
        "    # 1) Warmup phase\n",
        "    result_cold, final_kernel_args, rhat_forge = \\\n",
        "      forge_chain(target_rhat = target_rhat,\n",
        "                  warmup_window_size = warmup_window_size,\n",
        "                  kernel_cold = kernel_cold,\n",
        "                  initial_state = initial_state,\n",
        "                  max_num_steps = max_num_steps,\n",
        "                  seed = seed + 1, monitor = False,\n",
        "                  use_nested_rhat = True,\n",
        "                  use_log_joint = False)\n",
        "\n",
        "    length_warmup = np.append(length_warmup,\n",
        "                              len(rhat_forge) * warmup_window_size)\n",
        "  \n",
        "   # 2) Sampling phase\n",
        "    current_state = result_cold[-1]\n",
        "\n",
        "    result_warm = tfp.mcmc.sample_chain(\n",
        "        num_results = 50,\n",
        "        current_state = current_state,\n",
        "        kernel = kernel_warm,\n",
        "        previous_kernel_results = final_kernel_args,\n",
        "        seed = seed + 2,\n",
        "        return_final_kernel_results = None, trace_fn = None)\n",
        "\n",
        "    # Compute error based on first iteration of the sampling phase.\n",
        "    mc_err = np.append(mc_err,\n",
        "                       np.square(result_warm[0, :, parameter_index].mean()\n",
        "                                 - mean_est[parameter_index]))\n",
        "    \n",
        "    # Store nested Rhat computed using first 5 sampling iterations\n",
        "    nested_rhat_store = np.append(nested_rhat_store,\n",
        "    nested_rhat(result_warm[0:5, :, :], num_super_chain = num_super_chains)[0])\n",
        "\n",
        "    # END seed for loop\n",
        "  mc_err_mean = np.append(mc_err_mean, mc_err.mean())\n",
        "  mc_err_median = np.append(mc_err_median, np.median(mc_err))\n",
        "  mc_err_sd = np.append(mc_err_sd, mc_err.std())\n",
        "  length_warmup_mean = np.append(length_warmup_mean, length_warmup.mean())\n",
        "  length_warmup_sd = np.append(length_warmup_sd, length_warmup.std())\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2faFeLuLd8A3"
      },
      "outputs": [],
      "source": [
        "print(result_warm[0, :, parameter_index].mean())\n",
        "print(mean_est)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gpgbyx5r4Lrx"
      },
      "outputs": [],
      "source": [
        "# # Compute school effect\n",
        "# theta = (result_warm[:, :, 0, jnp.newaxis] +\n",
        "#   jnp.exp(result_warm[:, :, 1, jnp.newaxis]) * result_warm[:, :, 2:10])\n",
        "\n",
        "# # print(\"Our estimates:\", mc_est_warm(result_warm.mean(0))[0:2])\n",
        "# print(\"Our estimates:\", np.mean(result_warm.mean(0), axis = 0)[0:2])\n",
        "# print(\"Correct values:\", mean_est[0:2])\n",
        "\n",
        "# print(\"Our school estimates:\", theta.mean(axis = [0, 1]))\n",
        "# print(\"True estimates:\", mean_est[2:10])\n",
        "\n",
        "num_chains_short\n",
        "print(\"length warmup: \", length_warmup)\n",
        "print(\"mc_err\", mc_err)\n",
        "print(\"nested-Rhat\", nested_rhat_store)\n",
        "# print(\"mc_mean (excluding outlier):\", mc_err[1:10].mean())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lT5Dd7cgIV9w"
      },
      "outputs": [],
      "source": [
        "print(\"mc_err_mean:\", mc_err_mean, \"+/-\", mc_err_sd)\n",
        "print(\"mc_err_median:\", mc_err_median)\n",
        "print(\"warmup length:\", length_warmup_mean, \"+/-\", length_warmup_sd)\n",
        "print(var_est[parameter_index] / np.array(num_chains_array))\n",
        "result_warm.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xBO_4KRV7dTa"
      },
      "outputs": [],
      "source": [
        "# For experiments that exhaust memory, enter results from previous kernels\n",
        "enter_manually = False  # True\n",
        "if enter_manually:\n",
        "  num_chains_array = np.array([16, 32, 64, 128, 256, 512])\n",
        "  mc_err_mean = np.array([1.517, 0.877, 0.474, 0.388, 0.139, 0.576])\n",
        "  mc_err_sd = np.array([2.45, 0.693, 0.629, 0.791, 0.14, 1.304])\n",
        "  mc_err_median = np.array([0.476, 0.693, 0.173, 0.0460, 0.075, 0.054])\n",
        "  length_warmup_mean = np.array([1025., 1025., 722, 460, 437, 400])\n",
        "  length_warmup_sd = np.array([0., 0., 160, 124, 127, 115])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p6N_8AdG82Tr"
      },
      "outputs": [],
      "source": [
        "print(mc_err_mean)\n",
        "print(num_chains_array)\n",
        "print(3 * mc_err_sd / np.sqrt(10))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H4hquc7sQjQt"
      },
      "outputs": [],
      "source": [
        "figure(figsize = [6, 6])\n",
        "errorbar(x = num_chains_array, y = mc_err_mean, yerr = 3 * mc_err_sd / np.sqrt(10), label = 'Observed mean')\n",
        "# plot(num_chains_array, mc_err_median, label = 'Observed median')\n",
        "plot(num_chains_array, var_est[parameter_index] / np.array(num_chains_array), linestyle = '--', label = 'Expected')\n",
        "legend(loc = 'best')\n",
        "xlabel(\"Number of chains\")\n",
        "ylabel(\"Squared error\")\n",
        "show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H5jc9roqA4Ab"
      },
      "outputs": [],
      "source": [
        "figure(figsize = [6, 6])\n",
        "errorbar(x = num_chains_array, y = length_warmup_mean, yerr = length_warmup_sd)\n",
        "plot(num_chains_array, length_warmup_mean, 'o')\n",
        "ylabel(\"Warmup length\")\n",
        "xlabel(\"Number of chains\")\n",
        "show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yod3AS2uA941"
      },
      "source": [
        "### Further analysis for last run"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3yFCDOiZCE3i"
      },
      "outputs": [],
      "source": [
        "# Takes at least 10 iterations to estimate ESS?\n",
        "ess_warm = np.sum(tfp.mcmc.effective_sample_size(result_warm[1:10, :, :]))\n",
        "result_warm.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VQyjBc_qBAGl"
      },
      "outputs": [],
      "source": [
        "print(\"nested_rhat:\", nested_rhat(result_warm[0:5, :, :], num_super_chain = 4))\n",
        "print(\"Parameter index: \", parameter_index)\n",
        "print(\"Number of iterations: \", result_warm.shape)\n",
        "print(\"Ess: \", ess_warm)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JuSjBScgC1Be"
      },
      "outputs": [],
      "source": [
        "print(mean_est[0])\n",
        "print(np.mean(result_warm[0, :, 0]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-_lnnryuHCnT"
      },
      "outputs": [],
      "source": [
        "print(np.mean(result_warm[0:5, :, 0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FG6L9uJyEKye"
      },
      "outputs": [],
      "source": [
        "num_samples_plot = 4  # target_iter_mean\n",
        "plot(result_warm[1:5, 256:512, 1])\n",
        "show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3XzcM7nKIQgP"
      },
      "outputs": [],
      "source": [
        "# compute mean of super chains and between super chain variance (using first 5 iterations)\n",
        "result_state = result_warm[0, :, :]\n",
        "\n",
        "used_samples = 1 # result_state.shape[0]\n",
        "num_sub_chains = result_state.shape[0] // num_super_chains\n",
        "num_dimensions = result_state.shape[1]\n",
        "\n",
        "chain_states = result_state.reshape(used_samples, -1, num_sub_chains,\n",
        "                                    num_dimensions)\n",
        "\n",
        "mean_superchain = np.mean(chain_states, axis = [0, 2])\n",
        "\n",
        "B = mean_superchain.var(0, ddof = 1)\n",
        "V = result_state.var(axis = [0, 1], ddof = 1)\n",
        "\n",
        "print(used_samples)\n",
        "print(result_state.shape)\n",
        "print(chain_states.shape)\n",
        "print(mean_superchain.shape, mean_superchain[:, 0])\n",
        "print(B)\n",
        "print(V)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XvksiJbzOiY3"
      },
      "outputs": [],
      "source": [
        "np.unique(initial_state[:, 0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oiwHCtkmMMRG"
      },
      "outputs": [],
      "source": [
        "m = np.mean([2.26, 1.78, 2.12, 2.01])\n",
        "S = (2.26 - m)**2 + (1.78 - m)**2 + (2.12 - m)**2 + (2.01 - m)**2\n",
        "S / 3"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9FJanNXr5s2g"
      },
      "source": [
        "# Draft Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dJjlslmOTOzv"
      },
      "outputs": [],
      "source": [
        "mc_err_mean_adaptive = mc_err_mean\n",
        "mc_err_sd_adaptive = mc_err_sd"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ENVzgBPd01Lt"
      },
      "outputs": [],
      "source": [
        "print(\"Expected squared error:\", var_est / num_chains_short)\n",
        "print(\"Mean warmup length:\", length_warmup.mean(), \"+/-\", length_warmup.std())\n",
        "print(\"Mean Monte Carlo squared error:\", mc_err.mean(), \"+/-\", mc_err.std())\n",
        "print(\"Median Monte Carlo squared error:\", np.median(mc_err))\n",
        "\n",
        "# hist(result.all_states[:, :, 0].flatten(), 30, log=True)\n",
        "hist(mc_err, 30)\n",
        "show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SXTPFp8tB4Vx"
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "current_state = result_cold[-1]\n",
        "\n",
        "result_warm = tfp.mcmc.sample_chain(\n",
        "      num_results = 5,\n",
        "      current_state = current_state,\n",
        "      kernel = kernel_warm,\n",
        "      previous_kernel_results = final_kernel_args,\n",
        "      seed = random.PRNGKey(100001),\n",
        "      return_final_kernel_results = None, trace_fn = None)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YcfPNyhuOanD"
      },
      "outputs": [],
      "source": [
        "\n",
        "target_precision = var_est / 100\n",
        "mc_est_warm(result_warm.mean(1))[0, 0]"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Adaptive_warmup.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1bMcVyNvdA2_a66ZsvJxPSbtVYlA2ZAVe",
          "timestamp": 1632344166394
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
